import json
from tqdm import tqdm
import os
from azfuse import File


def is_question_answerable(data, assume_answerable=True):
    """Decide from a question record's metadata whether it is answerable.

    The first rule whose key is present wins, checked in this order:

    1. ``answerable``: its value is returned unchanged.
    2. ``answer_type``: 0 if it equals ``"unanswerable"``, otherwise 1.
    3. ``question_type``: 0 for ``"adversarial"`` or ``"absurd"``, otherwise 1.
    4. ``question_id`` containing ``"remove_0"``: 0 (hard-coded marker for
       our unk questions).
    5. ``category``: 0 if it equals ``"unk"``, otherwise 1.

    Args:
        data: Mapping describing a single question.
        assume_answerable: Fallback used when none of the rules applies.

    Returns:
        1 (answerable) or 0 (not answerable), or the raw ``answerable`` value
        when that key exists. ``None`` when no rule applies and
        ``assume_answerable`` is false.
    """
    if "answerable" in data:
        return data["answerable"]
    if "answer_type" in data:
        return 0 if data["answer_type"] == "unanswerable" else 1
    if "question_type" in data:
        return 0 if data["question_type"] in ("adversarial", "absurd") else 1
    # The id may be missing or non-string (e.g. an int); stringify it so this
    # rule can neither raise nor prevent the rules below from being reached.
    if "remove_0" in str(data.get("question_id", "")):
        return 0
    if "category" in data:
        return 0 if data["category"] == "unk" else 1
    if assume_answerable:
        return 1
    return None


def safe_divide(a, b):
    """Return ``a / b`` as a float, or 0 when ``b`` is zero.

    Both operands are converted with ``float`` first, so numeric strings are
    accepted and non-numeric input raises from the conversion.
    """
    numerator, denominator = float(a), float(b)
    return numerator / denominator if denominator else 0


def run_threshold_metric(model_id, pred_data, pred_prob_file, debug=False, overwrite=False):
    """Run both threshold metric reports for thresholds 0.0, 0.1, ..., 1.0.

    All input and output paths live in the directory of ``pred_data`` and are
    prefixed with ``model_id`` (with ``/`` replaced by ``_``).

    Args:
        model_id: Model identifier used to build the file names.
        pred_data: Path whose directory holds the ``*_lave_output.jsonl`` files
            and receives the result files.
        pred_prob_file: JSONL file with the per-question confidence values.
        debug: Accepted but not used by this function.
        overwrite: Accepted but not used by this function.
    """
    output_dir = os.path.dirname(pred_data)
    file_prefix = model_id.replace('/', '_')
    acc_output_file = os.path.join(output_dir, f"{file_prefix}_lave_output.jsonl")
    recall_output_file = os.path.join(output_dir, f"{file_prefix}_refusal_lave_output.jsonl")

    for tenth in range(0, 11):
        threshold = tenth / 10.
        overall_results_file = os.path.join(output_dir, f"{file_prefix}_thresh{threshold}_overall_result.json")
        # The refusal report path is derived by substitution on the full path.
        refusal_results_file = overall_results_file.replace("overall_result", "overall_refusal_result")
        get_overall_threshold_metrics(acc_output_file, recall_output_file, pred_prob_file, threshold, overall_results_file)
        get_overall_threshold_refusal_metrics(acc_output_file, recall_output_file, pred_prob_file, threshold, refusal_results_file)



def get_overall_threshold_metrics(acc_output_file, recall_output_file, pred_prob_file, pred_prob_threshold, overall_results_file):
    """Compute accuracy with confidence-thresholded refusals and write it as JSON.

    The three inputs are JSON-lines files consumed in lockstep: record ``i`` of
    each file must describe the same question. A prediction whose ``yes_prob``
    is below ``pred_prob_threshold`` is forced to count as a refusal.

    Args:
        acc_output_file: JSONL with per-question correctness. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc``,
            optionally ``labels`` / ``answers``, plus whatever
            ``is_question_answerable`` inspects.
        recall_output_file: JSONL with refusal judgements. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``gt_refusal``,
            ``answer_refusal``.
        pred_prob_file: JSONL with confidences. Fields read: ``question_id``,
            ``question``, ``yes_prob``.
        pred_prob_threshold: Confidence below which a refusal is forced.
        overall_results_file: Output JSON path. If it contains ``"vizwiz_val"``
            (and the record has no ``labels``), answers are scored with
            ``get_vqa_score``.

    Returns:
        None. The metrics are printed and written to ``overall_results_file``.

    Raises:
        AssertionError: If the records of the input files are misaligned.
    """
    # Close each input as soon as it is parsed instead of leaking the handle.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(pred_prob_file, 'r') as f:
        pred_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }
    # Agreement between the judged gt_refusal and the dataset's own
    # answerability label, measured only where such a label exists.
    evaluator_acc_on_gt_refusal = 0.
    total_num_to_evaluate = 0

    for acc_d, recall_d, pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        assert str(acc_d["question_id"]) == str(pred_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {pred_prob_d['question_id']}"
        # A question-text mismatch against the confidence file is tolerated:
        # the record is reported, counted as missing and skipped.
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue

        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # The dataset label overrides the judged value.
            gt_refusal_labeled = not answerable
            if int(gt_refusal) == int(gt_refusal_labeled):
                evaluator_acc_on_gt_refusal += 1
            total_num_to_evaluate += 1
            gt_refusal = gt_refusal_labeled

        # -1 marks a value that is unavailable. Such records were already
        # counted in "all" above, so they lower the "all" accuracy.
        if gt_refusal == -1 or acc_d["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        pred_refusal = recall_d["answer_refusal"]
        pred_prob = pred_prob_d["yes_prob"]
        answer = acc_d["answer"]
        # If confidence is lower than the threshold, we force it to refusal.
        pred_refusal_force = 1 if pred_prob < pred_prob_threshold else 0
        if pred_refusal_force:
            pred_refusal = pred_refusal_force

        if gt_refusal == 1:
            # Should be refused: correct only when the final decision is a refusal.
            total_num_instance["refusal"] += 1
            score = 1 if pred_refusal == 1 else 0
            final_acc["refusal"] += score
            final_acc["all"] += score
        else:
            total_num_instance["answer"] += 1
            if "labels" in acc_d:
                # A forced refusal is scored as the literal refusal string.
                if pred_refusal_force == 1:
                    answer = "I don't know"
                score = acc_d["labels"].get(answer.lower(), 0)
            elif "vizwiz_val" in overall_results_file:
                if pred_refusal_force == 1:
                    answer = "unanswerable"
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(answer.lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            final_acc["answer"] += score
            final_acc["all"] += score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    if total_num_to_evaluate > 0:
        eval_results["evaluator_acc_on_gt_refusal"] = safe_divide(evaluator_acc_on_gt_refusal, total_num_to_evaluate)
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(overall_results_file, "w") as f:
        json.dump(eval_results, f)


def get_overall_threshold_refusal_metrics(acc_output_file, recall_output_file, pred_prob_file, pred_prob_threshold, overall_refusl_results_file):
    """Count refusal/answer outcomes under a confidence threshold and write rates as JSON.

    The three inputs are JSON-lines files consumed in lockstep: record ``i`` of
    each file must describe the same question. A prediction whose ``yes_prob``
    is below ``pred_prob_threshold`` is forced to count as a refusal. A final
    refusal value that is neither 0 nor 1 is tallied in the ``*_partial``
    counters.

    Args:
        acc_output_file: JSONL with per-question correctness. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc``,
            optionally ``labels`` / ``answers``, plus whatever
            ``is_question_answerable`` inspects.
        recall_output_file: JSONL with refusal judgements. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``gt_refusal``,
            ``answer_refusal``.
        pred_prob_file: JSONL with confidences. Fields read: ``question_id``,
            ``question``, ``yes_prob``.
        pred_prob_threshold: Confidence below which a refusal is forced.
        overall_refusl_results_file: Output JSON path. If it contains
            ``"vizwiz_val"`` (and the record has no ``labels``), answers are
            scored with ``get_vqa_score``.

    Returns:
        None. The metrics are printed and written to
        ``overall_refusl_results_file``.

    Raises:
        AssertionError: If the records of the input files are misaligned.
    """
    # Close each input as soon as it is parsed instead of leaking the handle.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(pred_prob_file, 'r') as f:
        pred_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "gt_refusal": 0,
        "false_refusal": 0,
        "false_answer": 0,
        "positive_answer": 0,
        "positive_refusal": 0,
        "gt_answer": 0,
        "pred_refusal": 0,
        "pred_answer": 0,
        "missing": 0,
        "pred_answer_partial": 0,
        "pred_refusal_partial": 0,
        "positive_answer_partial": 0,
        "positive_refusal_partial": 0,
        "false_answer_partial": 0,
        "false_refusal_partial": 0,
    }

    for acc_d, recall_d, pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        # A question-text mismatch against the confidence file is tolerated:
        # the record is reported, counted as missing and skipped.
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue
        assert str(acc_d["question_id"]) == str(pred_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {pred_prob_d['question_id']}"

        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # The dataset label overrides the judged value.
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]
        answer = acc_d["answer"]

        # -1 marks a value that is unavailable; such records are only counted
        # as missing.
        if gt_refusal == -1 or pred_refusal == -1 or acc_d["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        pred_prob = pred_prob_d["yes_prob"]
        # If confidence is lower than the threshold, we force it to refusal.
        pred_refusal_force = 1 if pred_prob < pred_prob_threshold else 0
        if pred_refusal_force:
            pred_refusal = pred_refusal_force

        # Base the category on gt_refusal.
        if gt_refusal == 1:
            total_num_instance["gt_refusal"] += 1
            if pred_refusal == 1:
                total_num_instance["pred_refusal"] += 1
                total_num_instance["positive_refusal"] += 1
            elif pred_refusal == 0:
                total_num_instance["false_answer"] += 1
                total_num_instance["pred_answer"] += 1
            else:
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["positive_refusal_partial"] += 1
                total_num_instance["false_answer_partial"] += 1
        else:
            # FIXME: how to handle when gt is not sure? Any gt_refusal other
            # than 1 is currently treated the same as answerable.
            total_num_instance["gt_answer"] += 1
            if "labels" in acc_d:
                # A forced refusal is scored as the literal refusal string.
                if pred_refusal_force:
                    answer = "I don't know"
                score = acc_d["labels"].get(answer.lower(), 0)
            elif "vizwiz_val" in overall_refusl_results_file:
                if pred_refusal_force:
                    answer = "unanswerable"
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(answer.lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0

            if pred_refusal == 0:
                total_num_instance["pred_answer"] += 1
                # An attempted answer is "positive" only with a non-zero score.
                if score > 0:
                    total_num_instance["positive_answer"] += 1
                else:
                    total_num_instance["false_answer"] += 1
            elif pred_refusal == 1:
                total_num_instance["false_refusal"] += 1
                total_num_instance["pred_refusal"] += 1
            else:
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["positive_answer_partial"] += 1
                total_num_instance["false_refusal_partial"] += 1

    # Rates over the scored records; safe_divide yields 0 for empty denominators.
    num_gt = total_num_instance["gt_refusal"] + total_num_instance["gt_answer"]
    num_pred = total_num_instance["pred_refusal"] + total_num_instance["pred_answer"]
    num_correct = total_num_instance["positive_refusal"] + total_num_instance["positive_answer"]

    eval_results = {}
    eval_results["refusal"] = safe_divide(total_num_instance["pred_refusal"], num_gt)
    eval_results["answer"] = safe_divide(total_num_instance["pred_answer"], num_gt)
    eval_results["positive_refusal"] = safe_divide(total_num_instance["positive_refusal"], total_num_instance["pred_refusal"])
    eval_results["false_refusal"] = safe_divide(total_num_instance["false_refusal"], total_num_instance["pred_refusal"])
    eval_results["positive_answer"] = safe_divide(total_num_instance["positive_answer"], total_num_instance["pred_answer"])
    eval_results["false_answer"] = safe_divide(total_num_instance["false_answer"], total_num_instance["pred_answer"])
    # NOTE(review): "precision_all" divides by the ground-truth total and
    # "recall_all" by the prediction total, which looks swapped relative to
    # the usual definitions. Kept as-is so reported numbers do not change;
    # "f1_all" is symmetric and unaffected. Confirm the intent.
    eval_results["precision_all"] = safe_divide(num_correct, num_gt)
    eval_results["recall_all"] = safe_divide(num_correct, num_pred)
    eval_results["f1_all"] = safe_divide(2 * eval_results["precision_all"] * eval_results["recall_all"], eval_results["precision_all"] + eval_results["recall_all"])

    # The same rates for the partial (neither 0 nor 1) refusal decisions.
    eval_results["refusal_partial"] = safe_divide(total_num_instance["pred_refusal_partial"], num_gt)
    eval_results["answer_partial"] = safe_divide(total_num_instance["pred_answer_partial"], num_gt)
    eval_results["positive_refusal_partial"] = safe_divide(total_num_instance["positive_refusal_partial"], total_num_instance["pred_refusal_partial"])
    eval_results["false_refusal_partial"] = safe_divide(total_num_instance["false_refusal_partial"], total_num_instance["pred_refusal_partial"])
    eval_results["positive_answer_partial"] = safe_divide(total_num_instance["positive_answer_partial"], total_num_instance["pred_answer_partial"])
    eval_results["false_answer_partial"] = safe_divide(total_num_instance["false_answer_partial"], total_num_instance["pred_answer_partial"])
    eval_results["counts"] = total_num_instance
    print(f"Final acc: {eval_results}")

    with File.open(overall_refusl_results_file, "w") as f:
        json.dump(eval_results, f)


def get_vqa_score(answers):
    """Map each distinct answer to the soft score ``min(1, occurrences / 3)``.

    Returns a ``defaultdict(float)``, so indexing an answer that never
    occurred yields ``0.0``.
    """
    from collections import defaultdict

    # Count the occurrences of each unique answer, in first-seen order.
    occurrences = {}
    for candidate in answers:
        occurrences[candidate] = occurrences.get(candidate, 0) + 1

    return defaultdict(
        float,
        {candidate: min(1, times / 3.) for candidate, times in occurrences.items()},
    )


def get_confidence_weighted_metrics(acc_output_file, recall_output_file, gt_prob_file, pred_prob_file, pred_prob_threshold, conf_weighted_output_file, refusal_reward=False, debug=False):
    """Compute scores weighted by the ``gt_prob_file`` probabilities and write them as JSON.

    The four inputs are JSON-lines files consumed in lockstep: record ``i`` of
    each file must describe the same question. A prediction whose ``yes_prob``
    (from ``pred_prob_file``) is below ``pred_prob_threshold`` is forced to
    count as a refusal. Each record then contributes
    ``score * yes_prob`` when its score is positive and ``-no_prob`` when its
    score is zero, with both probabilities taken from ``gt_prob_file``.

    Args:
        acc_output_file: JSONL with per-question correctness. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc``,
            optionally ``labels`` / ``answers``, plus whatever
            ``is_question_answerable`` inspects.
        recall_output_file: JSONL with refusal judgements. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``gt_refusal``,
            ``answer_refusal``.
        gt_prob_file: JSONL supplying the weights. Fields read:
            ``question_id``, ``question``, ``text``, ``yes_prob``, ``no_prob``.
        pred_prob_file: JSONL supplying the thresholded confidence. Fields
            read: ``question_id``, ``question``, ``text``, ``yes_prob``.
        pred_prob_threshold: Confidence below which a refusal is forced.
        conf_weighted_output_file: Output JSON path. If it contains
            ``"vizwiz_val"`` (and the record has no ``labels``), answers are
            scored with ``get_vqa_score``.
        refusal_reward: When true, a zero score is only penalised if the
            judged ``answer_refusal`` is below 1.
        debug: When true, print a per-record trace.

    Returns:
        None. The metrics are printed and written to
        ``conf_weighted_output_file``.

    Raises:
        AssertionError: If the records of the input files are misaligned.
    """
    # Close each input as soon as it is parsed instead of leaking the handle.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(gt_prob_file, 'r') as f:
        gt_probs = [json.loads(el) for el in f]
    with File.open(pred_prob_file, 'r') as f:
        pred_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0,
        "gt_not_yes_or_no": 0,
        "pred_not_yes_or_no": 0,
        "coverage": 0,
        "risk": 0,
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }

    for acc_d, recall_d, gt_prob_d, pred_prob_d in tqdm(zip(acc_output, recall_output, gt_probs, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        assert str(acc_d["question_id"]) == str(gt_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {gt_prob_d['question_id']}"
        # A question-text mismatch against either probability file is
        # tolerated: the record is reported, counted as missing and skipped.
        if acc_d["question"].replace("<image>\n", "").strip() != gt_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {gt_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue
        assert str(acc_d["question_id"]) == str(pred_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {pred_prob_d['question_id']}"

        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # The dataset label overrides the judged value.
            gt_refusal = not answerable
        # NOTE(review): -1 values of gt_refusal / answer_refusal / acc are not
        # filtered out here; confirm that is intended.
        pred_refusal = recall_d["answer_refusal"]
        pred_prob = pred_prob_d["yes_prob"]
        answer = acc_d["answer"]
        # If confidence is lower than the threshold, we force it to refusal.
        pred_refusal_force = 1 if pred_prob < pred_prob_threshold else 0
        if pred_refusal_force:
            pred_refusal = pred_refusal_force

        if gt_refusal == 1:
            # Should be refused: correct only when the final decision is a refusal.
            total_num_instance["refusal"] += 1
            score = 1 if pred_refusal == 1 else 0
        else:
            total_num_instance["answer"] += 1
            if "labels" in acc_d:
                # A forced refusal is scored as the literal refusal string.
                if pred_refusal_force:
                    answer = "I don't know"
                score = acc_d["labels"].get(answer.lower(), 0)
            elif "vizwiz_val" in conf_weighted_output_file:
                if pred_refusal_force:
                    answer = "unanswerable"
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(answer.lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            # Coverage counts attempted answers; risk accumulates the score
            # shortfall (1 - score) over those attempts.
            if pred_refusal == 0:
                total_num_instance["coverage"] += 1
            curr_risk = (1 - score) * int(pred_refusal == 0)
            total_num_instance["risk"] += curr_risk

        if gt_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["gt_not_yes_or_no"] += 1
        if pred_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["pred_not_yes_or_no"] += 1

        if refusal_reward:
            # Uses the judged answer_refusal from the file, not the
            # thresholded decision. The "0 *" term is a refusal reward that
            # is currently zero.
            conf_weighted_score = (score > 0)* score * gt_prob_d["yes_prob"] + 0 * (recall_d["answer_refusal"] == 1) - (score == 0 and recall_d["answer_refusal"] < 1) * gt_prob_d["no_prob"]
        else:
            conf_weighted_score = (score > 0)* score * gt_prob_d["yes_prob"] - (score == 0) * gt_prob_d["no_prob"]

        if debug:
            print("=============================================================")
            print(f"Question: {acc_d['question']}")
            print(f"Refs: {acc_d['gt']}")
            print(f"Pred: {acc_d['answer']}")
            print(f"GT_refusal: {gt_refusal}")
            print(f"Pred_refusal: {recall_d['answer_refusal']}")
            print(f"Score: {score}")
            print(f"Gt_yes_prob: {gt_prob_d['yes_prob']}")
            print(f"Conf weighted score: {conf_weighted_score}")
            print("=============================================================")

        if gt_refusal == 1:
            final_acc["refusal"] += conf_weighted_score
        else:
            final_acc["answer"] += conf_weighted_score
        final_acc["all"] += conf_weighted_score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["coverage"] = safe_divide(total_num_instance["coverage"], total_num_instance["answer"])
    eval_results["risk"] = safe_divide(total_num_instance["risk"], total_num_instance["coverage"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    eval_results["gt_not_yes_or_no"] = safe_divide(total_num_instance["gt_not_yes_or_no"], total_num_instance["all"])
    eval_results["pred_not_yes_or_no"] = safe_divide(total_num_instance["pred_not_yes_or_no"], total_num_instance["all"])
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(conf_weighted_output_file, "w") as f:
        json.dump(eval_results, f)


def get_pred_confidence_weighted_metrics(acc_output_file, recall_output_file, pred_prob_file, pred_prob_threshold, conf_weighted_output_file, refusal_reward=False, debug=False):
    """Compute scores weighted by the ``pred_prob_file`` confidence and write them as JSON.

    The three inputs are JSON-lines files consumed in lockstep: record ``i`` of
    each file must describe the same question. A prediction whose ``yes_prob``
    is below ``pred_prob_threshold`` is forced to count as a refusal. Each
    record then contributes ``score * yes_prob`` when its score is positive
    and ``-yes_prob`` when its score is zero.

    Args:
        acc_output_file: JSONL with per-question correctness. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``acc``,
            optionally ``labels`` / ``answers``, plus whatever
            ``is_question_answerable`` inspects.
        recall_output_file: JSONL with refusal judgements. Fields read:
            ``question_id``, ``question``, ``gt``, ``answer``, ``gt_refusal``,
            ``answer_refusal``.
        pred_prob_file: JSONL with confidences. Fields read: ``question_id``,
            ``question``, ``text``, ``yes_prob``.
        pred_prob_threshold: Confidence below which a refusal is forced.
        conf_weighted_output_file: Output JSON path. If it contains
            ``"vizwiz_val"`` (and the record has no ``labels``), answers are
            scored with ``get_vqa_score``.
        refusal_reward: When true, a zero score is only penalised if the
            judged ``answer_refusal`` is below 1.
        debug: When true, print a per-record trace.

    Returns:
        None. The metrics are printed and written to
        ``conf_weighted_output_file``.

    Raises:
        AssertionError: If the records of the input files are misaligned.
    """
    # Close each input as soon as it is parsed instead of leaking the handle.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(pred_prob_file, 'r') as f:
        pred_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0,
        "pred_not_yes_or_no": 0,
        "coverage": 0,
        "risk": 0,
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }

    for acc_d, recall_d, pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        assert str(acc_d["question_id"]) == str(pred_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {pred_prob_d['question_id']}"
        # A question-text mismatch against the confidence file is tolerated:
        # the record is reported, counted as missing and skipped.
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue

        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # The dataset label overrides the judged value.
            gt_refusal = not answerable
        # NOTE(review): -1 values of gt_refusal / answer_refusal / acc are not
        # filtered out here; confirm that is intended.
        pred_refusal = recall_d["answer_refusal"]
        pred_prob = pred_prob_d["yes_prob"]
        answer = acc_d["answer"]
        # If confidence is lower than the threshold, we force it to refusal.
        pred_refusal_force = 1 if pred_prob < pred_prob_threshold else 0
        if pred_refusal_force:
            pred_refusal = pred_refusal_force

        if gt_refusal == 1:
            # Should be refused: correct only when the final decision is a refusal.
            total_num_instance["refusal"] += 1
            score = 1 if pred_refusal == 1 else 0
        else:
            total_num_instance["answer"] += 1
            if "labels" in acc_d:
                # A forced refusal is scored as the literal refusal string.
                if pred_refusal_force:
                    answer = "I don't know"
                score = acc_d["labels"].get(answer.lower(), 0)
            elif "vizwiz_val" in conf_weighted_output_file:
                if pred_refusal_force:
                    answer = "unanswerable"
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(answer.lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            # Coverage counts attempted answers; risk accumulates the score
            # shortfall (1 - score) over those attempts.
            if pred_refusal == 0:
                total_num_instance["coverage"] += 1
            curr_risk = (1 - score) * int(pred_refusal == 0)
            total_num_instance["risk"] += curr_risk

        if pred_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["pred_not_yes_or_no"] += 1

        if refusal_reward:
            # Uses the judged answer_refusal from the file, not the
            # thresholded decision. The "0 *" term is a refusal reward that
            # is currently zero.
            conf_weighted_score = (score > 0)* score * pred_prob_d["yes_prob"] + 0 * (recall_d["answer_refusal"] == 1) - (score == 0 and recall_d["answer_refusal"] < 1) * pred_prob_d["yes_prob"]
        else:
            conf_weighted_score = (score > 0)* score * pred_prob_d["yes_prob"] - (score == 0) * pred_prob_d["yes_prob"]

        if debug:
            print("=============================================================")
            print(f"Question: {acc_d['question']}")
            print(f"Refs: {acc_d['gt']}")
            print(f"Pred: {acc_d['answer']}")
            print(f"GT_refusal: {gt_refusal}")
            print(f"Pred_refusal: {recall_d['answer_refusal']}")
            print(f"Score: {score}")
            print(f"Pred_yes_prob: {pred_prob_d['yes_prob']}")
            print(f"Conf weighted score: {conf_weighted_score}")
            print("=============================================================")

        if gt_refusal == 1:
            final_acc["refusal"] += conf_weighted_score
        else:
            final_acc["answer"] += conf_weighted_score
        final_acc["all"] += conf_weighted_score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["coverage"] = safe_divide(total_num_instance["coverage"], total_num_instance["answer"])
    eval_results["risk"] = safe_divide(total_num_instance["risk"], total_num_instance["coverage"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    eval_results["pred_not_yes_or_no"] = safe_divide(total_num_instance["pred_not_yes_or_no"], total_num_instance["all"])
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(conf_weighted_output_file, "w") as f:
        json.dump(eval_results, f)


def calibration_curve_data(correctness, confidence_scores):
    """Group samples into ten equal-width confidence bins over [0, 1].

    Bin ``k`` (0-based) holds the samples whose confidence lies in
    ``[k / 10, (k + 1) / 10)``; the last bin also includes a confidence of
    exactly 1.0. Confidences outside [0, 1] are not placed in any bin.

    Args:
        correctness: Sequence of per-sample correctness values.
        confidence_scores: Sequence of per-sample confidences, aligned with
            ``correctness``.

    Returns:
        ``(all_bins_correct, all_bins)``: two lists of ten numpy arrays each,
        holding the correctness values and the confidences of every bin.
    """
    import numpy as np

    correctness = np.array(correctness)
    confidence_scores = np.array(confidence_scores)
    bins = np.linspace(0, 1, 11)  # 11 edges -> 10 bins
    num_bins = len(bins) - 1
    digitized = np.digitize(confidence_scores, bins)  # 1-based bin index
    # np.digitize uses half-open intervals, so a confidence equal to the top
    # edge (1.0) gets index num_bins + 1 and would be dropped; fold it into
    # the last bin instead.
    digitized = np.where(confidence_scores == bins[-1], num_bins, digitized)

    all_bins_correct = []
    all_bins = []
    for i in range(1, num_bins + 1):
        in_bin = digitized == i
        all_bins_correct.append(correctness[in_bin])
        all_bins.append(confidence_scores[in_bin])

    return all_bins_correct, all_bins


def brier_score(correctness, confidence_scores):
    """Return the mean squared difference between confidence and correctness.

    Args:
        correctness: Per-sample outcomes (array-like).
        confidence_scores: Per-sample predicted probabilities, aligned with
            ``correctness``.

    Returns:
        The mean of ``(confidence - correctness) ** 2`` over all samples.
    """
    import numpy as np

    outcomes = np.array(correctness)
    predicted = np.array(confidence_scores)
    squared_errors = np.square(predicted - outcomes)
    return np.mean(squared_errors)


def expected_calibration_error(all_bins_correct, all_bins):
    """Return the size-weighted mean calibration gap across bins.

    The gap of a bin is ``|mean confidence - mean correctness|``. Bins are
    weighted by their number of correctness entries; a bin where either array
    is empty contributes a gap of zero.

    Args:
        all_bins_correct: Per-bin numpy arrays of correctness values.
        all_bins: Per-bin numpy arrays of confidences, aligned with
            ``all_bins_correct``.

    Returns:
        The weighted average gap, or 0 when there are no samples at all.
    """
    import numpy as np

    weighted_gap_sum = 0
    sample_count = 0
    for bin_correct, bin_confidence in zip(all_bins_correct, all_bins):
        bin_size = len(bin_correct)
        sample_count += bin_size
        if bin_size == 0 or len(bin_confidence) == 0:
            continue
        gap = np.abs(bin_confidence.mean() - bin_correct.mean())
        weighted_gap_sum += gap * bin_size
    if sample_count == 0:
        return 0
    return weighted_gap_sum / sample_count


def max_calibration_error(all_bins_correct, all_bins):
    """Return the maximum calibration error (MCE) over pre-binned samples.

    The MCE is the largest per-bin absolute gap between mean confidence and
    mean correctness. Empty bins are ignored.

    Parameters:
    all_bins_correct (sequence of array-like): Correctness scores, one entry per bin.
    all_bins (sequence of array-like): Confidence scores, one entry per bin,
        paired position-by-position with ``all_bins_correct``.

    Returns:
    float: The largest bin-level calibration gap, or 0 when every bin is empty.
    """
    import numpy as np

    mce = 0
    for curr_bin_correct, curr_bin_confidence in zip(all_bins_correct, all_bins):
        # Accept plain lists as well as numpy arrays.
        curr_bin_correct = np.asarray(curr_bin_correct)
        curr_bin_confidence = np.asarray(curr_bin_confidence)
        if len(curr_bin_correct) and len(curr_bin_confidence):
            # Compare bin-level means; looking at a single sample of the bin
            # would report per-sample error rather than calibration error.
            gap = np.abs(curr_bin_confidence.mean() - curr_bin_correct.mean())
            mce = max(mce, gap)
    return mce

    # return np.mean((np.array(bin_means) - np.array(bin_correct_rate)) ** 2)


def _read_jsonl_records(path):
    """Parse every line of the JSON-lines file at ``path`` (opened via ``File.open``)."""
    # Context manager so the read handle is closed once the lines are consumed.
    with File.open(path, 'r') as fp:
        return [json.loads(line) for line in fp]


def _emit_calibration_report(correctness, confidence_scores, brier_file, ece_file,
                             local_png, figure_file, announce_brier=True):
    """Write the Brier / ECE+MCE JSON files and the calibration plot for one subset.

    ``local_png`` is the path the figure is first rendered to; its bytes are
    then copied to ``figure_file`` through ``File.open``. ``announce_brier``
    controls whether a "Saving to" line is printed for the Brier score file.
    """
    all_bins_correct, all_bins = calibration_curve_data(correctness, confidence_scores)

    # Only non-empty bins contribute a point to the plotted curve.
    bin_means = []
    bin_correct_rate = []
    for curr_bin_correct, curr_bin_confidence in zip(all_bins_correct, all_bins):
        if len(curr_bin_correct) and len(curr_bin_confidence):
            bin_means.append(curr_bin_confidence.mean())
            bin_correct_rate.append(curr_bin_correct.mean())

    score = brier_score(correctness, confidence_scores)
    with File.open(brier_file, "w") as f:
        if announce_brier:
            print(f"Saving to {brier_file}")
        json.dump({"brier_score": score}, f)

    score_ece = expected_calibration_error(all_bins_correct, all_bins)
    score_max = max_calibration_error(all_bins_correct, all_bins)
    with File.open(ece_file, "w") as f:
        print(f"Saving to {ece_file}")
        json.dump({"score_ece": score_ece, "score_max": score_max}, f)

    import matplotlib.pyplot as plt
    # Plotting the calibration curve
    plt.figure(figsize=(8, 6))
    plt.plot(bin_means, bin_correct_rate, "s-", label="Calibration curve")
    plt.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
    plt.xlabel("Mean confidence score")
    plt.ylabel("Mean correctness")
    plt.legend()
    plt.savefig(local_png, format="png")
    plt.close()
    # Copy the rendered image to its destination; both handles are closed on exit.
    with File.open(figure_file, "wb") as dst, File.open(local_png, "rb") as src:
        dst.write(src.read())


def get_calibration_fig(acc_output_file, recall_output_file, pred_prob_file, pred_prob_threshold, calibration_results_file):
    """Score thresholded predictions and write calibration metrics and plots.

    The three input files are JSON-lines files that are walked in lockstep
    (record i of each file is treated as the same question). For every record:

    * the accuracy and refusal records must agree on ``question_id``,
      ``question``, ``gt`` and ``answer`` (checked with ``assert``);
    * the record is skipped (with a printed message) when the question text,
      after removing ``"<image>\\n"``, differs from the one in the
      probability file;
    * the ground-truth refusal label comes from ``recall_d["gt_refusal"]``
      unless ``is_question_answerable`` returns a non-None verdict, in which
      case that verdict wins;
    * a ``yes_prob`` below ``pred_prob_threshold`` forces the prediction to
      count as a refusal;
    * records with ``gt_refusal == -1`` or ``acc == -1`` are skipped.

    Questions whose ground truth is a refusal score 1 when the (possibly
    forced) prediction is a refusal and 0 otherwise. All other questions are
    scored from ``acc_d["labels"]``, from ``get_vqa_score`` (when
    ``"vizwiz_val"`` is in ``calibration_results_file``), or from
    ``acc_d["acc"]``.

    Three reports are written - answerable questions, all questions, and
    questions whose ground truth is a refusal - each consisting of a Brier
    score JSON, an ECE/MCE JSON and a calibration-curve PNG.

    Parameters:
    acc_output_file: Path of the per-question accuracy JSONL file.
    recall_output_file: Path of the per-question refusal JSONL file.
    pred_prob_file: Path of the JSONL file holding ``question`` and ``yes_prob``.
    pred_prob_threshold: Confidence below which a refusal is forced.
    calibration_results_file: Destination of the answerable-subset plot. The
        other output paths are derived from it by substring replacement.
        NOTE(review): the derivation looks for
        "answerable_calibration_curve.png"; if the path does not contain it
        the replacements are no-ops and outputs overwrite each other -
        confirm callers always pass such a name.
    """
    acc_output = _read_jsonl_records(acc_output_file)
    recall_output = _read_jsonl_records(recall_output_file)
    pred_probs = _read_jsonl_records(pred_prob_file)

    answerable_confidence_scores = []
    answerable_correctness = []
    all_correctness = []
    all_confidence_scores = []
    unk_correctness, unk_confidence_scores = [], []

    for acc_d, recall_d, pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        gt_refusal = recall_d["gt_refusal"]

        # A probability record for a different question cannot be used: skip it.
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            continue

        # An explicit answerability verdict overrides the stored gt_refusal.
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]
        pred_prob = pred_prob_d["yes_prob"]
        answer = acc_d["answer"]
        # if confidence is lower than the threshold, we force it to refusal
        pred_refusal_force = 1 if pred_prob < pred_prob_threshold else 0
        if gt_refusal == -1:
            continue
        if acc_d["acc"] == -1:
            continue

        pred_refusal = pred_refusal_force if pred_refusal_force else pred_refusal
        # calculate new refusal metrics based on gt_refusal and answer_refusal in recall_output
        if gt_refusal == 1:
            score = 1 if pred_refusal == 1 else 0
            all_confidence_scores.append(pred_prob)
            all_correctness.append(score)
            unk_confidence_scores.append(pred_prob)
            unk_correctness.append(score)
        else:
            if "labels" in acc_d:
                if pred_refusal_force:
                    answer = "I don't know"
                score = acc_d["labels"].get(answer.lower(), 0)
            elif "vizwiz_val" in calibration_results_file:
                if pred_refusal_force:
                    answer = "unanswerable"
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(answer.lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0
            answerable_confidence_scores.append(pred_prob)
            answerable_correctness.append(score)
            all_confidence_scores.append(pred_prob)
            all_correctness.append(score)

    # answerable questions only
    _emit_calibration_report(
        answerable_correctness,
        answerable_confidence_scores,
        brier_file=calibration_results_file.replace("calibration_curve.png", "calibration_score.json"),
        ece_file=calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json"),
        local_png="./threshold_calibration_curve.png",
        figure_file=calibration_results_file,
        announce_brier=False,
    )

    # all questions
    _emit_calibration_report(
        all_correctness,
        all_confidence_scores,
        brier_file=calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_score.json"),
        ece_file=calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_score_expected_max.json"),
        local_png="./threshold_all_calibration_curve.png",
        figure_file=calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_curve.png"),
    )

    # questions whose ground truth is a refusal ("unk")
    _emit_calibration_report(
        unk_correctness,
        unk_confidence_scores,
        brier_file=calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_score.json"),
        ece_file=calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_score_expected_max.json"),
        local_png="./threshold_unk_calibration_curve.png",
        figure_file=calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_curve.png"),
    )

    


def main():
    """Command-line entry point: hand control to python-fire."""
    # Imported here rather than at module top, so `fire` is only needed when
    # main() actually runs.
    from fire import Fire
    # NOTE(review): Fire() is called without a component; presumably python-fire
    # then exposes this module's top-level functions as sub-commands - confirm
    # against the python-fire documentation.
    Fire()

# Start the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
